from sklearn.base import BaseEstimator
import numpy as np
from typing import Dict, Tuple
import re
from typing import Dict, Tuple
import numpy as np

class CustomModel(BaseEstimator):
    """Fixed-coefficient model with main effects and pairwise interactions.

    The linear predictor is::

        z = X @ beta_main + sum(b_ij * X[:, i] * X[:, j] for each stored pair)

    There is no intercept term. Coefficients are taken from ``param_dict``
    rather than learned; ``fit`` is a no-op.

    Args:
        param_dict: Mapping from a feature index (main effect) or a pair of
            feature indices (interaction) to its coefficient. See
            ``_parse_params`` for the accepted key forms.
        n_features: Number of columns ``predict`` expects in ``X``.
        link: ``"sigmoid"``, ``"identity"`` or ``None`` (treated as identity).
            Any other value makes ``predict`` raise unless logits are requested.
    """

    # Compiled once; used to pull every run of digits out of a non-int key.
    _INDEX_RE = re.compile(r'\d+')

    def __init__(self, param_dict: Dict, n_features: int = 9, link: str = "sigmoid"):
        self.param_dict = param_dict
        self.n_features = n_features
        self.link = link
        self.beta_main, self.beta_pairs = self._parse_params(param_dict, n_features)

    def _parse_params(self, param_dict: Dict, n_features: int = 9) -> Tuple[np.ndarray, dict]:
        """
        Parse parameter dictionary into main effects and interaction terms (0-based indices).

        - Main: single index (int or str), e.g. 0, "3"
        - Pair: two indices (tuple/list or any string with two integers), e.g. (0,1), "12-15", "3,23"

        Entries whose value is ``None``, whose indices fall outside
        ``[0, n_features)``, whose pair has ``i == j``, or whose key yields
        anything other than one or two integers are silently skipped. Pairs
        are stored under the sorted tuple ``(min, max)``, so ``(1, 0)`` and
        ``(0, 1)`` address the same coefficient and the later entry wins.

        NOTE: non-int keys are converted with ``str()`` and scanned for runs
        of digits, so signs are not honored: the string ``"-3"`` is read as
        index 3, whereas the int ``-3`` is out of range and skipped.

        Returns:
            ``(beta_main, beta_pairs)`` where ``beta_main`` is a float array of
            length ``n_features`` and ``beta_pairs`` maps ``(i, j)`` with
            ``i < j`` to a float coefficient.
        """
        beta_main = np.zeros(n_features, dtype=float)
        beta_pairs = {}

        for k, v in param_dict.items():
            if v is None:
                continue

            # Normalize key to a list of ints: [i] for main, [i, j] for pair
            if isinstance(k, (tuple, list)) and len(k) == 2:
                parts = list(map(int, k))
            elif isinstance(k, int):
                parts = [k]
            else:
                parts = list(map(int, self._INDEX_RE.findall(str(k))))

            if len(parts) == 1:
                i = parts[0]
                if 0 <= i < n_features:
                    beta_main[i] = float(v)
            elif len(parts) == 2:
                i, j = parts
                if 0 <= i < n_features and 0 <= j < n_features and i != j:
                    a, b = sorted((i, j))
                    beta_pairs[(a, b)] = float(v)
            # ignore anything else

        return beta_main, beta_pairs

    def fit(self, X, y):
        """Return ``self`` unchanged; the coefficients come from ``param_dict``."""
        return self

    def predict(self, X, output_logits: bool = False):
        """
        Predict using the CustomModel.

        Args:
            X: Input features, convertible to a float array of shape
                ``(N, n_features)``.
            output_logits: If True, return raw logits instead of applying link function

        Returns:
            Array of length N: the linear predictor when ``output_logits`` is
            True or the link is ``None``/``"identity"``, otherwise its sigmoid.

        Raises:
            ValueError: If ``X`` is not 2-D with ``n_features`` columns, or if
                the link is unrecognized and logits were not requested.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2 or X.shape[1] != self.n_features:
            raise ValueError(f"X must have shape (N, {self.n_features})")

        z = X @ self.beta_main
        if self.beta_pairs:
            pair_sum = np.zeros(X.shape[0], dtype=float)
            for (i, j), b in self.beta_pairs.items():
                pair_sum += b * X[:, i] * X[:, j]
            z = z + pair_sum

        if output_logits or self.link in (None, "identity"):
            return z
        if self.link == "sigmoid":
            # exp(-z) overflows to inf for very negative z; 1 / (1 + inf) is
            # still the correct 0.0, so only the spurious overflow warning
            # (or error, under a strict errstate) is suppressed here.
            with np.errstate(over="ignore"):
                return 1.0 / (1.0 + np.exp(-z))
        raise ValueError("link must be 'sigmoid' or 'identity'")

    def get_params(self, deep=True):
        """Return the constructor arguments as a dict; ``deep`` is accepted but unused."""
        return {
            'param_dict': self.param_dict,
            'n_features': self.n_features,
            'link': self.link
        }

    def set_params(self, **params):
        """Set attributes from keyword arguments and return ``self``.

        When ``param_dict`` or ``n_features`` is among the updated names, the
        derived coefficients (``beta_main`` / ``beta_pairs``) are re-parsed so
        that ``predict`` and ``coef_`` reflect the new configuration instead
        of the values computed at construction time.
        """
        for key, value in params.items():
            setattr(self, key, value)
        if 'param_dict' in params or 'n_features' in params:
            self.beta_main, self.beta_pairs = self._parse_params(
                self.param_dict, self.n_features
            )
        return self

    @property
    def coef_(self):
        """Return main coefficients (sklearn convention)"""
        return self.beta_main

    @property
    def coef_pairs_(self):
        """Return pair coefficients keyed by ``(i, j)`` with ``i < j``"""
        return self.beta_pairs

    @property
    def intercept_(self):
        """Return intercept (always 0 in this model)"""
        return 0.0


def generate_X_y(
    N: int,
    model: CustomModel,
    X_dist: str = "normal",  # 'normal' or 'uniform'
    random_state: int | None = None,
    feature_scale: float = 1.0,
    noise_scale: float = 0.1,  # Standard deviation of white noise
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate feature matrix X and labels y using CustomModel.

    For models with sigmoid link function, automatically generates binary labels (0/1)
    with a 0.5 decision boundary. For other link functions, generates continuous values
    with added white noise.

    Args:
        N: Number of samples to generate
        model: CustomModel instance to use for generating y values
        X_dist: Distribution for X features ('normal' or 'uniform')
        random_state: Random seed for reproducibility
        feature_scale: Scaling factor for features
        noise_scale: Standard deviation of Gaussian white noise added to logits (sigmoid)
                    or directly to predictions (other link functions)

    Returns:
        Tuple of (X, y) arrays where y is binary (0/1) for sigmoid link functions
        or continuous for other link functions
    """
    rng = np.random.default_rng(random_state)
    if X_dist == "normal":
        X = rng.normal(0.0, 1.0, size=(N, model.n_features)) * feature_scale
    elif X_dist == "uniform":
        X = rng.uniform(0.0, 1.0, size=(N, model.n_features)) * feature_scale
    else:
        raise ValueError("X_dist must be 'normal' or 'uniform'")

    if model.link == "sigmoid":
        # For sigmoid link function, generate binary classification labels
        # Get logits (raw linear combination before sigmoid)
        z = model.predict(X, output_logits=True)

        # Add noise to logits (more realistic than adding to probabilities)
        logit_noise = rng.normal(0.0, noise_scale, size=N)
        z_noisy = z + logit_noise

        # Apply sigmoid to get probabilities
        y_prob = 1.0 / (1.0 + np.exp(-z_noisy))

        # Convert to binary labels with 0.5 decision boundary
        y = (y_prob >= 0.5).astype(int)

    else:
        # For identity or other link functions, generate continuous values
        y_clean = model.predict(X)

        # Add white noise to each sample
        noise = rng.normal(0.0, noise_scale, size=N)
        y = y_clean + noise

    return X, y